import pandas as pd
import json
import traceback
import sympy as sp
from check_answer_rowwise import (
    cached_parsing, extract_latex_answer, latex_to_sympy,
    replace_infinite_sums, _meta_compare, _compare_symbolic_wrapper, 
    _compare_numeric_wrapper
)

from collect_llm_answers import ask_model

# CSV with the challenges to pose; `collect_answers` reads it and uses its
# `challenge_id` and `challenge` columns.
CHALLENGES_FILE = 'sample_data/asymob_challenges_sample.csv'
# CSV with the substitutions, generated using the `create_subs.py` script;
# `load_subs` reads it and uses its `challenge_id`, `subs_json` and
# `numerical_value` columns.
SUBS_FILE = 'sample_data/sample_subs.csv'

# Only a small number of models is enabled, to demonstrate the functionality.
# Make sure you've created an `api_keys.json` file with your API keys.
# Use the `api_keys_template.json` file as a template.
#
# Each entry is a `(model_name, use_code)` tuple; `collect_answers` forwards
# `use_code` to `ask_model` as its `code_execution` argument.
# NOTE(review): the commented-out entries use `None` for `use_code`;
# presumably that means "code execution not applicable" - confirm against
# `ask_model`.
MODELS_LIST = [
    # HuggingFace models:
    # ('DeepSeek-Prover-V2-671B', None),
    # ('DeepSeek-R1', None),
    # ('DeepSeek-V3', None),
    # ('meta-llama/Llama-4-Scout-17B-16E-Instruct', None),
    # ('nvidia/Llama-3_3-Nemotron-Super-49B-v1', None),
    # ('Qwen/Qwen2.5-72B-Instruct', None),

    # Google models:
    ('gemini/gemini-2.0-flash', False),
    ('gemini/gemini-2.0-flash', True),
    # ('gemini/gemini-2.5-flash-preview-04-17', False),
    # ('gemini/gemini-2.5-flash-preview-04-17', True),
    # ('gemini/gemini-2.5-flash', True),
    # ('gemini/gemini-2.5-flash', False),
    # ('gemini/gemma-3n-e4b-it', None),

    # OpenAI models:
    ('gpt-4.1', False),
    ('gpt-4.1', True),
    # ('gpt-4o', False),
    # ('gpt-4o', True),
    # ('gpt-4o-mini', False),
    # ('gpt-4o-mini', True),
    # ('o4-mini', False),
    # ('o4-mini', True),
]

def collect_answers(challenges_file=CHALLENGES_FILE):
    """
    Ask every configured model every challenge and gather the replies.

    Args:
        challenges_file: path of a CSV that has at least the columns
            `challenge_id` and `challenge`.

    Returns:
        A list with one dict per (challenge, model configuration) pair.
        Each dict holds `model_name` and `use_code`, followed by the
        challenge row's fields and then the fields returned by `ask_model`;
        on a key clash the later source wins.
    """
    collected = []
    challenge_rows = pd.read_csv(challenges_file)

    for _, row in challenge_rows.iterrows():
        challenge_fields = row.to_dict()
        print(f"Challenge ID: {challenge_fields['challenge_id']}")

        for model_name, use_code in MODELS_LIST:
            print(f"  Model: {model_name}, use_code: {use_code}")
            model_reply = ask_model(
                model_name=model_name,
                question_text=challenge_fields['challenge'],
                code_execution=use_code,
            )
            run_info = {'model_name': model_name, 'use_code': use_code}
            collected.append({**run_info, **challenge_fields, **model_reply})

    return collected

def load_subs(subs_file=SUBS_FILE):
    """
    Load the substitutions table, indexed by challenge id.

    Args:
        subs_file: path of a CSV that has at least the columns
            `challenge_id`, `subs_json` and `numerical_value`.

    Returns:
        A DataFrame indexed by `challenge_id` in which `subs_json` has been
        replaced by a `subs_vals` column (the decoded JSON, appended as the
        last column) and `numerical_value` holds strings.
    """
    table = pd.read_csv(subs_file)
    # Decode the JSON payload into Python objects under a new column name.
    table['subs_vals'] = table['subs_json'].apply(json.loads)
    # Store the numeric values as strings.
    table['numerical_value'] = table['numerical_value'].apply(str)
    return table.drop(columns='subs_json').set_index('challenge_id')


def check_answer(question_data, numeric_subs, recheck_errors=False):
    """
    Parse a model's answer and compare it with the true answer.

    Args:
        question_data: dict containing the dataframe's row. The keys
            'true_answer' and 'full_answer' are read; the dict is modified
            in place and also returned.
        numeric_subs: substitution data, passed through unchanged to
            `_compare_numeric_wrapper`.
        recheck_errors: accepted for interface compatibility; this function
            does not use it.

    Returns:
        `question_data`, updated with the parsing and comparison results.
        The keys 'numeric_correct', 'strict_mode', 'latex_parsing_method',
        'model_answer_sympy', 'symbolic_correct' and 'final_answer_latex'
        are always present; any of them that was never set is None.
        If any step raises, the formatted traceback is stored under both
        'symbolic_comparison_error' and 'numeric_comparison_error' and the
        exception is not propagated.
    """
    try:
        # the true answer is already in "clean" form, so we don't need to
        # work hard for it.
        question_data['true_answer'] = cached_parsing(
            question_data['true_answer'])

        question_data['final_answer_latex'] = extract_latex_answer(
            question_data['full_answer'])

        model_answer, answer_type = latex_to_sympy(
            question_data['final_answer_latex']
        )
        model_answer = model_answer.expand().removeO()
        if model_answer.has(sp.Sum):
            model_answer = replace_infinite_sums(model_answer)

        question_data['model_answer_sympy'] = model_answer
        question_data['latex_parsing_method'] = answer_type

        should_check = _meta_compare(
            question_data['model_answer_sympy'],
            question_data['true_answer']
        )
        if should_check:
            question_data = _compare_symbolic_wrapper(question_data)
            question_data = _compare_numeric_wrapper(question_data, numeric_subs)
        else:
            question_data['symbolic_correct'] = False
            question_data['numeric_correct'] = False

    except Exception:
        # Best effort: one unparsable answer must not abort the whole run,
        # so record the traceback on the row instead of re-raising.
        ex = traceback.format_exc()
        print(f'Error parsing {(question_data)}')
        print(f'Error: {ex}')
        question_data['symbolic_comparison_error'] = 'Joined error:\n' + str(ex)
        question_data['numeric_comparison_error'] = 'Joined error:\n' + str(ex)

    # Guarantee a stable set of output keys regardless of where processing
    # stopped. 'final_answer_latex' is included because a failure before the
    # LaTeX extraction would otherwise leave the row without it.
    for key in (
            'numeric_correct', 'strict_mode',
            'latex_parsing_method', 'model_answer_sympy',
            'symbolic_correct', 'final_answer_latex'):
        question_data.setdefault(key, None)

    return question_data

def validate_answers(answers, subs_df):
    """
    Run `check_answer` on every collected answer.

    Args:
        answers: iterable of result dicts; each must contain the keys
            'answer_sympy' and 'challenge_id'. Every dict gets a
            'true_answer' entry written into it (a copy of 'answer_sympy').
        subs_df: substitutions DataFrame indexed by challenge id.

    Returns:
        A list of the dicts returned by `check_answer`, in input order.
        `check_answer` receives a shallow copy of each input dict.
    """
    def _validate_one(record):
        # Expose the reference answer under the key `check_answer` reads.
        record['true_answer'] = record['answer_sympy']
        # NOTE(review): `.loc[label]` yields a Series (flat list below) when
        # the id occurs once in the index and a DataFrame (list of rows)
        # when it occurs several times - confirm which shape
        # `_compare_numeric_wrapper` expects.
        substitutions = subs_df.loc[record['challenge_id']].values.tolist()
        return check_answer(
            dict(record),
            substitutions,
            recheck_errors=False,
        )

    return [_validate_one(record) for record in answers]

def main():
    """
    Run the sample pipeline end to end.

    Collects the models' answers, validates them against the substitutions
    table and writes the selected columns to `sample_results.csv`.
    """
    # Step 1 - query the models for every challenge.
    raw_answers = collect_answers()

    # Step 2 - grade each answer.
    substitutions = load_subs()
    graded = validate_answers(raw_answers, substitutions)

    # Step 3 - keep only the report columns and write them out.
    report_columns = [
        'challenge_id', 'model_name', 'use_code', 'full_answer',
        'final_answer_latex', 'true_answer', 'symbolic_correct',
        'numeric_correct',
    ]
    report = pd.DataFrame.from_records(graded)[report_columns]
    report.to_csv('sample_results.csv', index=False)
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()